from code_template import CodeTemplate
from function_wrapper import nested_dict

# Skeleton of the generated C++ source file: fixed includes (the THC header is
# guarded by AT_CUDA_ENABLED()), then ${copy_includes}, then ${copy_functions}
# wrapped in `namespace at`. Both placeholders are supplied by create().
FILE = CodeTemplate("""\
#include "ATen/Config.h"

#include "TH/TH.h"
#if AT_CUDA_ENABLED()
#undef THNN_
#include "THC/THC.h"
#endif
#include "ATen/Utils.h"
${copy_includes}

namespace at {

${copy_functions}

}
""")

# The copy call emitted for every (destination, source) pair: invokes
# ${THTensor}_copy${cuda}${src_scalar_name} on the destination tensor and the
# source tensor obtained by casting src.pImpl to ${src_tensor}.
# NOTE(review): `${state,}` presumably renders the `state` list followed by a
# trailing comma when it is non-empty -- confirm against code_template.
COPY = CodeTemplate("""\
${THTensor}_copy${cuda}${src_scalar_name}(${state,}self_->tensor, static_cast<${src_tensor}*>(src.pImpl)->tensor);
""")

# Async branch placed ahead of COPY by create_one() when the destination
# backend is CUDA, the source backend is CPU and the scalar types match.
# The emitted `break` sits inside the generated switch case, so when `async`
# is true the COPY statement that follows in the same case is skipped.
COPY_ASYNC_CPU = CodeTemplate("""\
if (async) {
    ${THTensor}_copyAsyncCPU(${state,}self_->tensor, static_cast<${src_tensor}*>(src.pImpl)->tensor);
    break;
}
""")

# Async branch placed ahead of COPY by create_one() when the destination
# backend is CPU, the source backend is CUDA and the scalar types match.
# The emitted `break` sits inside the generated switch case, so when `async`
# is true the COPY statement that follows in the same case is skipped.
COPY_ASYNC_CUDA = CodeTemplate("""\
if (async) {
    ${THTensor}_copyAsyncCuda(${state,}self_->tensor, static_cast<${src_tensor}*>(src.pImpl)->tensor);
    break;
}
""")

# One `case` label of the generated switch, keyed by the source type's
# ${src_id}. ${copies} receives the list of copy statements that create_one()
# builds for that source type, followed by an unconditional `break`.
CASE = CodeTemplate("""\
case ${src_id}:
    ${copies}
    break;
""")

# Definition of ${Type}::s_copy_ for one destination type. The generated body
# casts `self` to ${Tensor}, switches on the source type's ID (${copy_body}
# holds the per-source cases from create_one()), reports a runtime_error for
# any source type without a case, then copies the source's scalar flag onto
# `self` and returns it.
FUNCTION = CodeTemplate("""\
Tensor & ${Type}::s_copy_(Tensor & self, const Tensor & src, bool async) const {
  // code generated by function_wrapper
  auto self_ = checked_cast_tensor<${Tensor}>(self.pImpl, "self", 0,false);
  (void) self_; //silence unused warning
  switch (src.type().ID()) {
    ${copy_body}
    default:
      runtime_error("copy does not support %s to %s copy.", src.type().toString(), toString());
      break;
  }
  self.pImpl->setScalar(src.pImpl->isScalar());
  return self;
}
""")


def create_one(env, all_types):
    """Render the s_copy_ definition for a single destination type.

    Args:
        env: Description of the destination type. The keys read here are
            'Density', 'Backend' and 'ScalarType'; the whole mapping is also
            handed to the templates as the fallback substitution environment.
        all_types: Iterable of candidate source type descriptions. The keys
            read here are 'Density', 'Backend', 'ScalarName', 'ScalarType',
            'TypeID' and 'Tensor'.

    Returns:
        The result of substituting FUNCTION with one CASE per supported
        source type.
    """
    cases = []
    for src in all_types:
        # Sparse copies are not yet implemented, so no case is emitted when
        # either side of the copy is sparse.
        if env['Density'] == 'Sparse' or src['Density'] == 'Sparse':
            continue

        # A CUDA source selects the "Cuda" variant of the copy function name.
        cuda_suffix = 'Cuda' if src['Backend'] == 'CUDA' else ''

        # The THC state argument is passed whenever either side is CUDA.
        needs_state = env['Backend'] == 'CUDA' or src['Backend'] == 'CUDA'
        state_args = ['context->thc_state'] if needs_state else []

        scope = nested_dict({
            'src_scalar_name': src['ScalarName'],
            'src_id': src['TypeID'],
            'src_tensor': src['Tensor'],
            'cuda': cuda_suffix,
            'state': state_args,
        }, env)

        statements = []
        # The async branch is only emitted for cross-device copies between
        # identical scalar types; it always precedes the plain copy.
        if env['ScalarType'] == src['ScalarType']:
            direction = (env['Backend'], src['Backend'])
            if direction == ('CUDA', 'CPU'):
                statements.append(COPY_ASYNC_CPU.substitute(scope))
            elif direction == ('CPU', 'CUDA'):
                statements.append(COPY_ASYNC_CUDA.substitute(scope))
        statements.append(COPY.substitute(scope))

        cases.append(CASE.substitute(scope, copies=statements))

    return FUNCTION.substitute(env, copy_body=cases)


def create(all_types):
    """Render the complete generated copy source file.

    For every entry of ``all_types`` (treated as a destination type) this adds
    two ``#include "ATen/<name>.h"`` lines -- one for its 'Type' and one for
    its 'Tensor' -- and one function produced by create_one().

    Args:
        all_types: Sequence of type descriptions; it is iterated here as the
            list of destination types and passed unchanged to create_one()
            as the list of source types.

    Returns:
        The result of substituting FILE with the collected includes and
        functions.
    """
    includes = []
    functions = []
    for dst in all_types:
        # Header for the Type class first, then the Tensor class.
        includes.extend(
            '#include "ATen/{}.h"'.format(dst[key])
            for key in ('Type', 'Tensor'))
        functions.append(create_one(dst, all_types))

    return FILE.substitute({
        'copy_includes': includes,
        'copy_functions': functions,
    })
